{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FhGuhbZ6M5tl"
      },
      "source": [
        "##### Copyright 2022 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "AwOEIRJC6Une"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EIdT9iu_Z4Rb"
      },
      "source": [
        "# Distributed training with Core APIs and DTensor"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bBIlTPscrIT9"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/guide/core/distribution\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/guide/core/distribution.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/guide/core/distribution.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/docs/site/en/guide/core/distribution.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SjAxxRpBzVYg"
      },
      "source": [
        "## Introduction\n",
        "\n",
        "This notebook uses the [TensorFlow Core low-level APIs](https://www.tensorflow.org/guide/core) and [DTensor](https://www.tensorflow.org/guide/dtensor_overview) to demonstrate a data parallel distributed training example. Visit the [Core APIs overview](https://www.tensorflow.org/guide/core) to learn more about TensorFlow Core and its intended use cases. Refer to the [DTensor Overview](https://www.tensorflow.org/guide/dtensor_overview) guide and [Distributed Training with DTensors](https://www.tensorflow.org/tutorials/distribute/dtensor_ml_tutorial) tutorial to learn more about DTensor.\n",
        "\n",
        "This example uses the same model and optimizer shown in the [multilayer perceptrons](https://www.tensorflow.org/guide/core/mlp_core) tutorial. See this tutorial first to get comfortable with writing an end-to-end machine learning workflow with the Core APIs.\n",
        "\n",
        "Note: DTensor is still an experimental TensorFlow API which means that its features are available for testing, and it is intended for use in test environments only."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "d_OFkG0dyWCp"
      },
      "source": [
        "## Overview of data parallel training with DTensor\n",
        "\n",
        "Before building an MLP that supports distribution, take a moment to explore the fundamentals of DTensor for data parallel training.\n",
        "\n",
        "DTensor allows you to run distributed training across devices to improve efficiency, reliability and scalability. DTensor distributes the program and tensors according to the sharding directives through a procedure called Single program, multiple data (SPMD) expansion. A variable of a `DTensor` aware layer is created as `dtensor.DVariable`, and the constructors of `DTensor` aware layer objects take additional `Layout` inputs in addition to the usual layer parameters.\n",
        "\n",
        "The main ideas for data parallel training are as follows:\n",
        " - Model variables are replicated on N devices each.\n",
        " - A global batch is split into N per-replica batches.\n",
        " - Each per-replica batch is trained on the replica device.\n",
        " - The gradient is reduced before weight up data is collectively performed on all replicas.\n",
        " - Data parallel training provides nearly linear speed with respect to the number of devices"
      ]
    },
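    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The main ideas above can be illustrated without any TensorFlow machinery. The following is a minimal sketch in plain Python, assuming a hypothetical one-parameter linear model `y = w * x` with a mean squared error loss (neither is part of this tutorial's model). It suggests why the reduced gradient can be applied identically on every replica: averaging the gradients computed on equally sized per-replica batches gives the same value as computing the gradient on the global batch."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Illustrative sketch only: a hypothetical 1-parameter model y = w * x with MSE loss.\n",
        "def mse_gradient(w, xs, ys):\n",
        "  # d/dw of mean((w*x - y)^2) over the examples in (xs, ys)\n",
        "  return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)\n",
        "\n",
        "def split_evenly(values, num_replicas):\n",
        "  # Split a global batch into equally sized per-replica batches\n",
        "  shard_size = len(values) // num_replicas\n",
        "  return [values[i * shard_size:(i + 1) * shard_size] for i in range(num_replicas)]\n",
        "\n",
        "sketch_w = 0.5\n",
        "sketch_xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]\n",
        "sketch_ys = [2.0 * x for x in sketch_xs]\n",
        "sketch_replicas = 4\n",
        "\n",
        "# Gradient computed on the whole global batch\n",
        "global_grad = mse_gradient(sketch_w, sketch_xs, sketch_ys)\n",
        "\n",
        "# Gradients computed independently on each replica, then averaged (the reduce step)\n",
        "replica_grads = [\n",
        "    mse_gradient(sketch_w, xs, ys)\n",
        "    for xs, ys in zip(split_evenly(sketch_xs, sketch_replicas),\n",
        "                      split_evenly(sketch_ys, sketch_replicas))]\n",
        "reduced_grad = sum(replica_grads) / sketch_replicas\n",
        "\n",
        "print(abs(global_grad - reduced_grad) < 1e-9)"
      ]
    },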
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nchsZfwEVtVs"
      },
      "source": [
        "## Setup\n",
        "\n",
        "DTensor is part of the TensorFlow 2.9.0 release."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "latuqlI_Yvoo"
      },
      "outputs": [],
      "source": [
        "#!pip install --quiet --upgrade --pre tensorflow"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1rRo8oNqZ-Rj"
      },
      "outputs": [],
      "source": [
        "import matplotlib\n",
        "from matplotlib import pyplot as plt\n",
        "# Preset Matplotlib figure sizes.\n",
        "matplotlib.rcParams['figure.figsize'] = [9, 6]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9xQKvCJ85kCQ"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "import tensorflow_datasets as tfds\n",
        "from tensorflow.experimental import dtensor\n",
        "print(tf.__version__)\n",
        "# Set random seed for reproducible results \n",
        "tf.random.set_seed(22)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vDH9-sy4sfPf"
      },
      "source": [
        "Configure 8 virtual CPUs for this experiment. DTensor can also be used with GPU or TPU devices. Given that this notebook uses virtual devices, the speedup gained from distributed training is not noticeable. "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "H2iM-6J4s2D6"
      },
      "outputs": [],
      "source": [
        "def configure_virtual_cpus(ncpu):\n",
        "  phy_devices = tf.config.list_physical_devices('CPU')\n",
        "  tf.config.set_logical_device_configuration(phy_devices[0], [\n",
        "        tf.config.LogicalDeviceConfiguration(),\n",
        "    ] * ncpu)\n",
        "\n",
        "configure_virtual_cpus(8)\n",
        "\n",
        "DEVICES = [f'CPU:{i}' for i in range(8)]\n",
        "devices = tf.config.list_logical_devices('CPU')\n",
        "device_names = [d.name for d in devices]\n",
        "device_names"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "F_72b0LCNbjx"
      },
      "source": [
        "## The MNIST Dataset\n",
        "\n",
        "The dataset is available from [TensorFlow Datasets](https://www.tensorflow.org/datasets/catalog/mnist). Split the data into training and testing sets. Only use 5000 examples for training and testing to save time."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8h4fV_JCfPIX"
      },
      "outputs": [],
      "source": [
        "train_data, test_data = tfds.load(\"mnist\", split=['train[:5000]', 'test[:5000]'], batch_size=128, as_supervised=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "twkJ35YB6tSi"
      },
      "source": [
        "### Preprocessing the data\n",
        "\n",
        "Preprocess the data by reshaping it to be 2-dimensional and by rescaling it to fit into the unit interval, [0,1]."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6Cmjhg0xCqbz"
      },
      "outputs": [],
      "source": [
        "def preprocess(x, y):\n",
        "  # Reshaping the data\n",
        "  x = tf.reshape(x, shape=[-1, 784])\n",
        "  # Rescaling the data\n",
        "  x = x/255\n",
        "  return x, y\n",
        "\n",
        "train_data, test_data = train_data.map(preprocess), test_data.map(preprocess)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6o3CrycBXA2s"
      },
      "source": [
        "## Build the MLP \n",
        "\n",
        "Build an MLP model with DTensor aware layers."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OHW6Yvg2yS6H"
      },
      "source": [
        "### The dense layer\n",
        "\n",
        "Start by creating a dense layer module that supports DTensor. The `dtensor.call_with_layout` function can be used to call a function that takes in a DTensor input and produces a DTensor output. This is useful for initializing a DTensor variable, `dtensor.DVariable`, with a TensorFlow supported function."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IM0yJos25FG5"
      },
      "outputs": [],
      "source": [
        "class DenseLayer(tf.Module):\n",
        "\n",
        "  def __init__(self, in_dim, out_dim, weight_layout, activation=tf.identity):\n",
        "    super().__init__()\n",
        "    # Initialize dimensions and the activation function\n",
        "    self.in_dim, self.out_dim = in_dim, out_dim\n",
        "    self.activation = activation\n",
        "\n",
        "    # Initialize the DTensor weights using the Xavier scheme\n",
        "    uniform_initializer = tf.function(tf.random.stateless_uniform)\n",
        "    xavier_lim = tf.sqrt(6.)/tf.sqrt(tf.cast(self.in_dim + self.out_dim, tf.float32))\n",
        "    self.w = dtensor.DVariable(\n",
        "      dtensor.call_with_layout(\n",
        "          uniform_initializer, weight_layout,\n",
        "          shape=(self.in_dim, self.out_dim), seed=(22, 23),\n",
        "          minval=-xavier_lim, maxval=xavier_lim))\n",
        "        \n",
        "    # Initialize the bias with zeros\n",
        "    bias_layout = weight_layout.delete([0])\n",
        "    self.b = dtensor.DVariable(\n",
        "      dtensor.call_with_layout(tf.zeros, bias_layout, shape=[out_dim]))\n",
        "\n",
        "  def __call__(self, x):\n",
        "    # Compute the forward pass\n",
        "    z = tf.add(tf.matmul(x, self.w), self.b)\n",
        "    return self.activation(z)"
      ]
    },
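    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The Xavier (Glorot) uniform scheme used in `DenseLayer` draws weights from the interval `[-lim, lim]`, where `lim = sqrt(6) / sqrt(in_dim + out_dim)`. As a quick sanity check, the sketch below evaluates this limit in plain Python for a 784 x 700 layer. The helper name `xavier_limit` is only for illustration and is not part of any TensorFlow API."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import math\n",
        "\n",
        "def xavier_limit(in_dim, out_dim):\n",
        "  # Same formula as xavier_lim in DenseLayer: sqrt(6) / sqrt(in_dim + out_dim)\n",
        "  return math.sqrt(6.0) / math.sqrt(in_dim + out_dim)\n",
        "\n",
        "print(f'Xavier limit for a 784 x 700 layer: {xavier_limit(784, 700):.4f}')"
      ]
    },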
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "X-7MzpjgyHg6"
      },
      "source": [
        "### The MLP sequential model\n",
        "\n",
        "Now create an MLP module that executes the dense layers sequentially."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6XisRWiCyHAb"
      },
      "outputs": [],
      "source": [
        "class MLP(tf.Module):\n",
        "\n",
        "  def __init__(self, layers):\n",
        "    self.layers = layers\n",
        "   \n",
        "  def __call__(self, x, preds=False): \n",
        "    # Execute the model's layers sequentially\n",
        "    for layer in self.layers:\n",
        "      x = layer(x)\n",
        "    return x"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "r5HZJ0kv-V3v"
      },
      "source": [
        "Performing data parallel training with DTensor is equivalent to using `tf.distribute.MirroredStrategy`: each device runs the same model on a shard of the data batch. To set this up, you need the following:\n",
        "\n",
        "* A `dtensor.Mesh` with a single `\"batch\"` dimension\n",
        "* A `dtensor.Layout` for all the weights that replicates them across the mesh (using `dtensor.UNSHARDED` for each axis)\n",
        "* A `dtensor.Layout` for the data that splits the batch dimension across the mesh\n",
        "\n",
        "Create a DTensor mesh that consists of a single batch dimension, where each device becomes a replica that receives a shard from the global batch. Use this mesh to instantiate an MLP model with the following architecture:\n",
        "\n",
        "Forward Pass: ReLU(784 x 700) x ReLU(700 x 500) x Softmax(500 x 10)\n",
        "\n",
        "The final layer outputs logits; the softmax is applied inside the loss function."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "VmlACuki3oPi"
      },
      "outputs": [],
      "source": [
        "mesh = dtensor.create_mesh([(\"batch\", 8)], devices=DEVICES)\n",
        "weight_layout = dtensor.Layout([dtensor.UNSHARDED, dtensor.UNSHARDED], mesh)\n",
        "\n",
        "input_size = 784\n",
        "hidden_layer_1_size = 700\n",
        "hidden_layer_2_size = 500\n",
        "output_size = 10\n",
        "\n",
        "mlp_model = MLP([\n",
        "    DenseLayer(in_dim=input_size, out_dim=hidden_layer_1_size,\n",
        "               weight_layout=weight_layout,\n",
        "               activation=tf.nn.relu),\n",
        "    DenseLayer(in_dim=hidden_layer_1_size, out_dim=hidden_layer_2_size,\n",
        "               weight_layout=weight_layout,\n",
        "               activation=tf.nn.relu),\n",
        "    DenseLayer(in_dim=hidden_layer_2_size, out_dim=output_size,\n",
        "               weight_layout=weight_layout)])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tyBATDoRmDkg"
      },
      "source": [
        "### Training metrics\n",
        "\n",
        "Use the cross-entropy loss function and accuracy metric for training."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rskOYA7FVCwg"
      },
      "outputs": [],
      "source": [
        "def cross_entropy_loss(y_pred, y):\n",
        "  # Compute cross entropy loss with a sparse operation\n",
        "  sparse_ce = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=y_pred)\n",
        "  return tf.reduce_mean(sparse_ce)\n",
        "\n",
        "def accuracy(y_pred, y):\n",
        "  # Compute accuracy after extracting class predictions\n",
        "  class_preds = tf.argmax(y_pred, axis=1)\n",
        "  is_equal = tf.equal(y, class_preds)\n",
        "  return tf.reduce_mean(tf.cast(is_equal, tf.float32))"
      ]
    },
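    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The sparse cross-entropy used above can be written per example as the log-sum-exp of the logits minus the logit of the true class. The following plain-Python sketch, with a hypothetical helper `sparse_softmax_ce`, spells out that computation. It is meant as an illustration of the math, not as a replacement for `tf.nn.sparse_softmax_cross_entropy_with_logits`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import math\n",
        "\n",
        "def sparse_softmax_ce(logits, label):\n",
        "  # Cross entropy for one example: log(sum(exp(logits))) - logits[label]\n",
        "  # Subtract the max logit first for numerical stability\n",
        "  max_logit = max(logits)\n",
        "  log_sum_exp = max_logit + math.log(sum(math.exp(z - max_logit) for z in logits))\n",
        "  return log_sum_exp - logits[label]\n",
        "\n",
        "# With 10 equal logits, every class has probability 1/10, so the loss is log(10)\n",
        "uniform_loss = sparse_softmax_ce([0.0] * 10, label=3)\n",
        "print(abs(uniform_loss - math.log(10)) < 1e-12)"
      ]
    },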
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JSiNRhTOnKZr"
      },
      "source": [
        "### Optimizer\n",
        "\n",
        "Using an optimizer can result in significantly faster convergence compared to standard gradient descent. The Adam optimizer is implemented below and has been configured to be compatible with DTensor. To use Keras optimizers with DTensor, refer to the experimental `tf.keras.dtensor.experimental.optimizers` module."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-9kIAI_lfXDS"
      },
      "outputs": [],
      "source": [
        "class Adam(tf.Module):\n",
        "\n",
        "    def __init__(self, model_vars, learning_rate=1e-3, beta_1=0.9, beta_2=0.999, ep=1e-7):\n",
        "      # Initialize optimizer parameters and variable slots\n",
        "      self.model_vars = model_vars\n",
        "      self.beta_1 = beta_1\n",
        "      self.beta_2 = beta_2\n",
        "      self.learning_rate = learning_rate\n",
        "      self.ep = ep\n",
        "      self.t = 1.\n",
        "      self.v_dvar, self.s_dvar = [], []\n",
        "      # Initialize optimizer variable slots\n",
        "      for var in model_vars:\n",
        "        v = dtensor.DVariable(dtensor.call_with_layout(tf.zeros, var.layout, shape=var.shape))\n",
        "        s = dtensor.DVariable(dtensor.call_with_layout(tf.zeros, var.layout, shape=var.shape))\n",
        "        self.v_dvar.append(v)\n",
        "        self.s_dvar.append(s)\n",
        "\n",
        "    def apply_gradients(self, grads):\n",
        "      # Update the model variables given their gradients\n",
        "      for i, (d_var, var) in enumerate(zip(grads, self.model_vars)):\n",
        "        self.v_dvar[i].assign(self.beta_1*self.v_dvar[i] + (1-self.beta_1)*d_var)\n",
        "        self.s_dvar[i].assign(self.beta_2*self.s_dvar[i] + (1-self.beta_2)*tf.square(d_var))\n",
        "        v_dvar_bc = self.v_dvar[i]/(1-(self.beta_1**self.t))\n",
        "        s_dvar_bc = self.s_dvar[i]/(1-(self.beta_2**self.t))\n",
        "        var.assign_sub(self.learning_rate*(v_dvar_bc/(tf.sqrt(s_dvar_bc) + self.ep)))\n",
        "      self.t += 1.\n",
        "      return "
      ]
    },
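    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see what the bias correction terms do, the following sketch applies a single Adam update to a scalar parameter in plain Python, using the same default hyperparameters as the `Adam` module above. The function `adam_scalar_step` is a hypothetical helper for illustration only. On the first step (`t = 1`), the bias-corrected moments reduce to the gradient and its square, so the parameter moves by approximately `learning_rate` in the direction opposite to the gradient."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import math\n",
        "\n",
        "def adam_scalar_step(w, grad, v, s, t, learning_rate=1e-3, beta_1=0.9, beta_2=0.999, ep=1e-7):\n",
        "  # Update the moving averages of the gradient and the squared gradient\n",
        "  v = beta_1 * v + (1 - beta_1) * grad\n",
        "  s = beta_2 * s + (1 - beta_2) * grad ** 2\n",
        "  # Bias-correct the moving averages\n",
        "  v_bc = v / (1 - beta_1 ** t)\n",
        "  s_bc = s / (1 - beta_2 ** t)\n",
        "  # Apply the update\n",
        "  w = w - learning_rate * v_bc / (math.sqrt(s_bc) + ep)\n",
        "  return w, v, s\n",
        "\n",
        "w_new, v_new, s_new = adam_scalar_step(w=1.0, grad=0.5, v=0.0, s=0.0, t=1)\n",
        "print(f'Parameter change after one step: {w_new - 1.0:.6f}')"
      ]
    },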
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "w54b7GtLfn1j"
      },
      "source": [
        "### Data packing\n",
        "\n",
        "Start by writing a helper function for transferring data to the device. This function should use `dtensor.pack` to send (and only send) the shard of the global batch that is intended for a replica to the device backing the replica. For simplicity, assume a single-client application.\n",
        "\n",
        "Next, write a function that uses this helper function to pack the training data batches into DTensors sharded along the batch (first) axis. This ensures that DTensor evenly distributes the training data to the 'batch' mesh dimension. Note that in DTensor, the batch size always refers to the global batch size; therefore, the batch size should be chosen such that it can be divided evenly by the size of the batch mesh dimension. Additional DTensor APIs to simplify `tf.data` integration are planned, so please stay tuned."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3Rx82djZ6ITm"
      },
      "outputs": [],
      "source": [
        "def repack_local_tensor(x, layout):\n",
        "  # Repacks a local Tensor-like to a DTensor with layout\n",
        "  # This function assumes a single-client application\n",
        "  x = tf.convert_to_tensor(x)\n",
        "  sharded_dims = []\n",
        "\n",
        "  # For every sharded dimension, use tf.split to split the tensor along that dimension.\n",
        "  # The result is a nested list of split-tensors in queue[0].\n",
        "  queue = [x]\n",
        "  for axis, dim in enumerate(layout.sharding_specs):\n",
        "    if dim == dtensor.UNSHARDED:\n",
        "      continue\n",
        "    num_splits = layout.shape[axis]\n",
        "    queue = tf.nest.map_structure(lambda x: tf.split(x, num_splits, axis=axis), queue)\n",
        "    sharded_dims.append(dim)\n",
        "\n",
        "  # Now you can build the list of component tensors by looking up the location in\n",
        "  # the nested list of split-tensors created in queue[0].\n",
        "  components = []\n",
        "  for locations in layout.mesh.local_device_locations():\n",
        "    t = queue[0]\n",
        "    for dim in sharded_dims:\n",
        "      split_index = locations[dim]  # Only valid on single-client mesh.\n",
        "      t = t[split_index]\n",
        "    components.append(t)\n",
        "\n",
        "  return dtensor.pack(components, layout)\n",
        "\n",
        "def repack_batch(x, y, mesh):\n",
        "  # Pack training data batches into DTensors along the batch axis\n",
        "  x = repack_local_tensor(x, layout=dtensor.Layout(['batch', dtensor.UNSHARDED], mesh))\n",
        "  y = repack_local_tensor(y, layout=dtensor.Layout(['batch'], mesh))\n",
        "  return x, y"
      ]
    },
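    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Because the batch axis is split evenly across the `batch` mesh dimension, every batch that reaches `repack_batch` must have a size divisible by the mesh size. The plain-Python sketch below checks this for the data used in this notebook, assuming the final partial batch is kept rather than dropped. The helper `batch_sizes` is hypothetical; it lists the batch sizes for 5000 examples in batches of 128 so that each can be tested against the 8 devices."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def batch_sizes(num_examples, batch_size):\n",
        "  # Sizes of the batches produced when the final partial batch is kept\n",
        "  full_batches, remainder = divmod(num_examples, batch_size)\n",
        "  return [batch_size] * full_batches + ([remainder] if remainder else [])\n",
        "\n",
        "sketch_sizes = batch_sizes(num_examples=5000, batch_size=128)\n",
        "sketch_mesh_size = 8\n",
        "all_divisible = all(size % sketch_mesh_size == 0 for size in sketch_sizes)\n",
        "print(f'Last batch size: {sketch_sizes[-1]}, all divisible by {sketch_mesh_size}: {all_divisible}')"
      ]
    },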
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "osEK3rqpYfKd"
      },
      "source": [
        "### Training\n",
        "\n",
        "Write a traceable function that executes a single training step given a batch of data. This function does not require any special DTensor annotations. Also write a function that executes a test step and returns the appropriate performance metrics."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZICEsDGuSbDD"
      },
      "outputs": [],
      "source": [
        "@tf.function\n",
        "def train_step(model, x_batch, y_batch, loss, metric, optimizer):\n",
        "  # Execute a single training step\n",
        "  with tf.GradientTape() as tape:\n",
        "    y_pred = model(x_batch)\n",
        "    batch_loss = loss(y_pred, y_batch)\n",
        "  # Compute gradients and update the model's parameters\n",
        "  grads = tape.gradient(batch_loss, model.trainable_variables)\n",
        "  optimizer.apply_gradients(grads)\n",
        "  # Return batch loss and accuracy\n",
        "  batch_acc = metric(y_pred, y_batch)\n",
        "  return batch_loss, batch_acc\n",
        "\n",
        "@tf.function\n",
        "def test_step(model, x_batch, y_batch, loss, metric):\n",
        "  # Execute a single testing step\n",
        "  y_pred = model(x_batch)\n",
        "  batch_loss = loss(y_pred, y_batch)\n",
        "  batch_acc = metric(y_pred, y_batch)\n",
        "  return batch_loss, batch_acc"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RjIDVTwwX-Mr"
      },
      "source": [
        "Now, train the MLP model for 3 epochs with a batch size of 128."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "oC85kuZgmh3q"
      },
      "outputs": [],
      "source": [
        "# Initialize the training loop parameters and structures\n",
        "epochs = 3\n",
        "batch_size = 128\n",
        "train_losses, test_losses = [], []\n",
        "train_accs, test_accs = [], []\n",
        "optimizer = Adam(mlp_model.trainable_variables)\n",
        "\n",
        "# Format training loop\n",
        "for epoch in range(epochs):\n",
        "  batch_losses_train, batch_accs_train = [], []\n",
        "  batch_losses_test, batch_accs_test = [], []\n",
        "\n",
        "  # Iterate through training data\n",
        "  for x_batch, y_batch in train_data:\n",
        "    x_batch, y_batch = repack_batch(x_batch, y_batch, mesh)\n",
        "    batch_loss, batch_acc = train_step(mlp_model, x_batch, y_batch, cross_entropy_loss, accuracy, optimizer)\n",
        "    # Keep track of batch-level training performance\n",
        "    batch_losses_train.append(batch_loss)\n",
        "    batch_accs_train.append(batch_acc)\n",
        "\n",
        "  # Iterate through testing data\n",
        "  for x_batch, y_batch in test_data:\n",
        "    x_batch, y_batch = repack_batch(x_batch, y_batch, mesh)\n",
        "    batch_loss, batch_acc = test_step(mlp_model, x_batch, y_batch, cross_entropy_loss, accuracy)\n",
        "    # Keep track of batch-level testing performance\n",
        "    batch_losses_test.append(batch_loss)\n",
        "    batch_accs_test.append(batch_acc)\n",
        "\n",
        "  # Keep track of epoch-level model performance\n",
        "  train_loss, train_acc = tf.reduce_mean(batch_losses_train), tf.reduce_mean(batch_accs_train)\n",
        "  test_loss, test_acc = tf.reduce_mean(batch_losses_test), tf.reduce_mean(batch_accs_test)\n",
        "  train_losses.append(train_loss)\n",
        "  train_accs.append(train_acc)\n",
        "  test_losses.append(test_loss)\n",
        "  test_accs.append(test_acc)\n",
        "  print(f\"Epoch: {epoch}\")\n",
        "  print(f\"Training loss: {train_loss.numpy():.3f}, Training accuracy: {train_acc.numpy():.3f}\")\n",
        "  print(f\"Testing loss: {test_loss.numpy():.3f}, Testing accuracy: {test_acc.numpy():.3f}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "j_RVmt43G12R"
      },
      "source": [
        "### Performance evaluation\n",
        "\n",
        "Start by writing a plotting function to visualize the model's loss and accuracy during training. "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "VXTCYVtNDjAM"
      },
      "outputs": [],
      "source": [
        "def plot_metrics(train_metric, test_metric, metric_type):\n",
        "  # Visualize metrics vs training Epochs\n",
        "  plt.figure()\n",
        "  plt.plot(range(len(train_metric)), train_metric, label = f\"Training {metric_type}\")\n",
        "  plt.plot(range(len(test_metric)), test_metric, label = f\"Testing {metric_type}\")\n",
        "  plt.xlabel(\"Epochs\")\n",
        "  plt.ylabel(metric_type)\n",
        "  plt.legend()\n",
        "  plt.title(f\"{metric_type} vs Training Epochs\");"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "407qok7q2JIO"
      },
      "outputs": [],
      "source": [
        "plot_metrics(train_losses, test_losses, \"Cross entropy loss\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8H_TgxV92NfX"
      },
      "outputs": [],
      "source": [
        "plot_metrics(train_accs, test_accs, \"Accuracy\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DHO_u-3w4YRF"
      },
      "source": [
        "## Saving your model\n",
        "\n",
        "The integration of `tf.saved_model` and DTensor is still under development. As of TensorFlow 2.9.0, `tf.saved_model` only accepts DTensor models with fully replicated variables. As a workaround, you can convert a DTensor model to a fully replicated one by reloading a checkpoint. However, after a model is saved, all DTensor annotations are lost and the saved signatures can only be used with regular Tensors. This tutorial will be updated to showcase the integration once it is solidified.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VFLfEH4ManbW"
      },
      "source": [
        "## Conclusion\n",
        "\n",
        "This notebook provided an overview of distributed training with DTensor and the TensorFlow Core APIs. Here are a few more tips that may help:\n",
        "\n",
        "- The [TensorFlow Core APIs](https://www.tensorflow.org/guide/core) can be used to build highly-configurable machine learning workflows with support for distributed training.\n",
        "- The [DTensor concepts](https://www.tensorflow.org/guide/dtensor_overview) guide and [Distributed training with DTensors](https://www.tensorflow.org/tutorials/distribute/dtensor_ml_tutorial) tutorial contain the most up-to-date information about DTensor and its integrations.\n",
        "\n",
        "For more examples of using the TensorFlow Core APIs, check out the [guide](https://www.tensorflow.org/guide/core). If you want to learn more about loading and preparing data, see the tutorials on [image data loading](https://www.tensorflow.org/tutorials/load_data/images) or [CSV data loading](https://www.tensorflow.org/tutorials/load_data/csv)."
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [
        "FhGuhbZ6M5tl"
      ],
      "name": "distribution.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
